# Add DCP compatibility for FSDP2-TP sharding in TransformerEngine. #2713

cspades wants to merge 11 commits into NVIDIA:main from cspades:cye/fsdp2-tp-dcp
## Sequence Diagram

```mermaid
sequenceDiagram
    participant User
    participant TEModule as TE Module (Linear/LayerNormLinear/etc.)
    participant Base as TransformerEngineBaseModule
    participant Dist as distributed.py
    participant FSDP2 as fully_shard (FSDP2)
    participant DCP as Torch DCP
    User->>TEModule: __init__(tp_size, tp_mesh, weight_mesh)
    TEModule->>TEModule: set_device_mesh(tp_mesh, weight_mesh)
    TEModule->>Dist: _convert_param_to_dtensor_param(param, tp_mesh, Shard/Replicate)
    Dist-->>TEModule: DTensor(local_param, placement)
    TEModule->>TEModule: reset_parameters(defer_init=True/False)
    User->>FSDP2: fully_shard(model, mesh[dp_dims])
    Note over FSDP2: Detects DTensor Shard(dim=0) on DP-matching dim<br/>→ uses _StridedShard placement for FSDP-TP
    User->>TEModule: reset_parameters() [meta device only]
    Base->>Base: quantize param (FP8)
    Base->>Dist: _convert_param_to_dtensor_param(fp8_param, dtensor.device_mesh, ...)
    Dist-->>Base: DTensor(fp8_param, same placement)
    rect rgb(200, 230, 255)
        Note over TEModule,FSDP2: Training Forward Pass
        FSDP2->>FSDP2: all-gather sharded DTensor weight
        Note over FSDP2: TP-sharded DTensor remains after all-gather
        TEModule->>Dist: _extract_trainable_tensor_from_dtensor(dtensor)
        Dist-->>TEModule: local Tensor (identity-preserved via _ToLocalIdentity)
        TEModule->>TEModule: C++ kernel (plain Tensor)
    end
    rect rgb(255, 230, 200)
        Note over TEModule,FSDP2: DCP Checkpoint Save/Load
        User->>DCP: save({"app": AppState(model, optimizer)})
        Note over DCP: AppState.state_dict() evicts _extra_state,<br/>clears empty optimizer states for empty params
        DCP-->>User: checkpoint written
        User->>DCP: load({"app": AppState(model, optimizer)})
        Note over DCP: set_state_dict(strict=False) ignores<br/>_extra_state and empty optimizer entries
        DCP-->>User: checkpoint restored
    end
```
## Summary

- Support (H/F)SDP2 x TP strided sharding and `DTensor` FP8 parameters for Torch DCP checkpointing, across all `TransformerEngineBaseModule`(s).
- Not yet supported: `GroupedLinear`, pending FSDP2 standalone pipe-cleaning. All other modules under `transformer_engine.pytorch.modules` are supported. `FusibleOperation` support is also a WIP, except for `LayerNorm` or `RMSNorm`, which are TE modules.
- Compatible with `DTensor`-based TP when unified by Torch DCP! In the Llama3 recipe, we use `DTensor`-based TP on the `torch.nn.Embedding`, TransformerEngine-based TP on the LM head, and weight-tie the LM head to the `torch.nn.Embedding`, which is why we do not need to call `set_device_mesh` for the LM head!

## Usage / Documentation
(`tp_mesh` and `weight_mesh` can also be passed in `TEModule.__init__`.)

## Details
### DTensor Lifecycle in TransformerEngine
1. Initialize TransformerEngine modules in `__init__` on the `meta` device with the appropriate `tp_size` and TP sharding strategy, e.g. `parallel_mode` and `sequence_parallel`.
2. Call `TransformerEngineModule.set_device_mesh(tp_mesh, weight_mesh)` to convert the parameters to `DTensor` with the appropriate TP `placement`(s) based on the TP sharding strategy specified in `__init__`, using `transformer_engine.pytorch.distributed._convert_param_to_dtensor_param`.
   - `tp_mesh` is a 1-D `DeviceMesh` containing the TP `ProcessGroup` that will be registered with the TransformerEngine module.
   - `weight_mesh` is the 1-D `DeviceMesh` containing the `ProcessGroup` that shards TransformerEngine module weights, i.e. the flattened combination of groups such as FSDP and TP. Specifically, it excludes non-weight groups such as DP-Replicate when using HSDP or HSDP-TP, and is mainly required for per-tensor scaling recipes like `Float8CurrentScaling`.
   - This must be called before `fully_shard` (which responds to the TP placements) and prior to `reset_parameters(defer_init=False)`, which quantizes parameters.
   - Alternatively, pass `__init__(tp_mesh, weight_mesh)` for supported TransformerEngine modules.
3. `fully_shard` shards the TransformerEngine model with FSDP2.
   - When `fully_shard` encounters TP sharding on `dim=0`, it will use a `_StridedShard` for DP. Put simply, this "pre-shards" the data prior to sharding on the current placement, followed by concatenating the pre-shards to get strided shards that will be re-sharded by the next placement. This effectively reverses the sharding order when processing the placements from left to right, and distributes shards as if we sharded on TP first, then FSDP, as required, even though DP appears before TP in the `DeviceMesh` and `DTensor.placements`. (See the Appendix for a visualization of this sharding strategy.)
4. `reset_parameters` is called after `fully_shard` if using meta device initialization.
(Note that this essentially shares the same properties as the compute weight besides shape, and supporting tools such as `FusedAdam` must be used to properly handle high-precision main weights.)

- During the forward pass, the all-gathered `Tensor` is actually a TP-sharded `DTensor`, which deviates from the original FSDP2 paradigm where the all-gathered `Tensor` is fully unsharded and the `DTensor` wrapping is discarded. To support these `DTensor` compute weights in TransformerEngine modules, we utilize `transformer_engine.pytorch.distributed._extract_trainable_tensor_from_dtensor` to localize the `DTensor` and also inherit the `requires_grad` attribute from the `DTensor` parameter, as the local `Tensor` has this un-set during `DTensor.from_local(Tensor)` for FP8 parameters specifically!
- During the backward pass, the local `Tensor` gradient is converted to a `DTensor` and attached to the `DTensor.grad` attribute. This is handled by DTensor <> Tensor Autograd conversion functions, and in the case of `FusibleOperation`, casted during the backward implementation.
- When the `QuantizedTensorStorage` data is `None`, we set `untyped_storage()` to a default 1-byte storage that unblocks DCP checkpoint loading assertions that use this as a definition for "emptiness". This is because a storage of 0 bytes has `data_ptr() = nullptr` and breaks DCP. While `untyped_storage` is not used anywhere in TransformerEngine, it may break code that uses `Storage` to figure out if a Tensor is empty or not, as now `QuantizedTensor` storage will always be a 1-byte storage even when both row and column data are not set. Those checks would instead need to compare the storage size against 1 byte instead of 0 bytes.
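To make the 1-byte-sentinel caveat concrete, here is a minimal hypothetical sketch (the helper name and sentinel constant are invented for illustration) of how an "is this quantized tensor empty?" check would need to change:

```python
# Assumption from the description above: an "empty" QuantizedTensor now carries
# a placeholder untyped storage of 1 byte, because a 0-byte storage has
# data_ptr() == nullptr and breaks DCP checkpoint loading.
SENTINEL_NBYTES = 1

def is_empty_quantized_storage(storage_nbytes: int) -> bool:
    # Old check was `storage_nbytes == 0`; that size is no longer reachable,
    # so emptiness must be compared against the 1-byte sentinel instead.
    return storage_nbytes <= SENTINEL_NBYTES

assert is_empty_quantized_storage(1)         # placeholder storage -> "empty"
assert not is_empty_quantized_storage(4096)  # real row/column data -> not empty
```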
"shard"was the presumed weight sharding sub-mesh in theDTensor.device_mesh. Now, users can precisely specify their own custom weight-shardingDeviceMeshfor per-tensoramax_reduction_groupvia theset_device_mesh(weight_mesh)API.TransformerEngineBaseModule:self.quantizers = {"scaling_fwd": [], "scaling_bwd": []}Testing
- There is a `num_zeros` test failure that is common to both `main` and `cspades:cye/fsdp2-tp-dcp`, so we can assume it is not associated with my change: https://github.com/NVIDIA/Megatron-LM/actions/runs/22637904520/job/65636890955?pr=3661 (TransformerEngine `main`)
- Tested `main` vs. `cspades:cye/fsdp2-tp-dcp` with Megatron-LM `main` on PyTorch 25.11.
- `DelayedScaling` has DCP save/load disparity issues, i.e. on the scale of +/- 1 to the `uint8` parameter checkpoint!

## Appendix
### `_StridedShard` - Using FSDP2 x TP Strided-Sharding

When `redistribute`'ing a global DTensor to `(_StridedShard(dim=0, sf=2), Shard(dim=0))`, `DTensor` will perform the following steps:

1. Compute the split factor from the `Shard` placements to the right of the `_StridedShard`. (In the above example, since TP=2, the factor is 2.)
2. Pre-shard the global tensor by the split factor: `[0 1 2 3 4 5 6 7] -> [0 1 2 3] and [4 5 6 7]`. (When sharding with `fully_shard`, this has already been done via initializing the TransformerEngine module with TP and calling `_convert_param_to_dtensor_param`!)
3. Shard on the `_StridedShard` placement.
   - Chunk each pre-shard: `[0] [1] [2] [3]` and `[4] [5] [6] [7]`.
   - Concatenate the pre-shard chunks into strided shards `[0 4] [1 5] [2 6] [3 7]`, which are assigned to the `_StridedShard` ranks. (A plain `Shard` would instead produce `[0 1] [2 3] [4 5] [6 7]`!)
4. Shard on the next `Shard` placement.
   - Chunk each strided shard: `[0] [4]` / `[1] [5]` / `[2] [6]` / `[3] [7]`, which are assigned to the `Shard` ranks.
   - This matches the assignment we would get by sharding on TP first and then FSDP: `[0] [1]` / `[2] [3]` / `[4] [5]` / `[6] [7]`!
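The steps above can be checked with a minimal pure-Python simulation (plain lists, no torch; the helper names are invented for illustration) showing that the strided two-step sharding assigns exactly the same elements as sharding TP first and then FSDP:

```python
def chunk(seq, n):
    """Split seq into n equal contiguous chunks."""
    size = len(seq) // n
    return [seq[i * size:(i + 1) * size] for i in range(n)]

def strided_then_shard(elems, dp, tp):
    """Simulate (_StridedShard(dim=0, sf=tp), Shard(dim=0)) over (dp, tp)."""
    pre_shards = chunk(elems, tp)                # [[0,1,2,3], [4,5,6,7]]
    # Chunk each pre-shard across DP, then concatenate the i-th pieces to
    # form the strided shard owned by DP rank i.
    pieces = [chunk(p, dp) for p in pre_shards]
    dp_shards = [sum((pieces[t][i] for t in range(tp)), []) for i in range(dp)]
    # dp_shards == [[0, 4], [1, 5], [2, 6], [3, 7]]
    # Finally, shard each strided shard contiguously across the TP ranks.
    return {(i, j): chunk(dp_shards[i], tp)[j]
            for i in range(dp) for j in range(tp)}

def tp_first_then_dp(elems, dp, tp):
    """Reference: shard on TP first, then FSDP within each TP shard."""
    tp_shards = chunk(elems, tp)
    return {(i, j): chunk(tp_shards[j], dp)[i]
            for i in range(dp) for j in range(tp)}

elems = list(range(8))
assert strided_then_shard(elems, dp=4, tp=2) == tp_first_then_dp(elems, dp=4, tp=2)
```

The per-rank dict keys are `(dp_rank, tp_rank)`; e.g. rank `(1, 1)` ends up with `[5]` under both schemes, which is the "as if TP were sharded first" property the `_StridedShard` exists to provide.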
PyTorch also supports the inverse / un-sharding of this `redistribute`, which is literally the inverse of these simple operations! (Though things get a bit more complicated with uneven shards from odd-numbered dimension sizes.)